{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/neild0/StyleDrop-PyTorch-Interactive/blob/main/styledrop_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# StyleDrop Demo\n",
        "\n",
        "<p align=\"left\">\n",
        "  <a href=\"https://huggingface.co/spaces/zideliu/styledrop\"><img alt=\"Huggingface\" src=\"https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-StyleDrop-orange\"></a>\n",
        "  <a href=\"https://replicate.com/cjwbw/styledrop\"><img src=\"https://replicate.com/cjwbw/styledrop/badge\"></a>\n",
        "</p>\n",
        "\n",
        "This is an unofficial PyTorch implementation of [StyleDrop: Text-to-Image Generation in Any Style](https://arxiv.org/abs/2306.00983), based on [zideliu's implementation](https://github.com/zideliu/StyleDrop-PyTorch).\n",
        "\n",
        "<img src=\"https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/main/img/1.png\"/>\n",
        "<img src=\"https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/main/img/2.png\"/>\n",
        "<img src=\"https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/main/img/5.png\"/>\n"
      ],
      "metadata": {
        "id": "IBNlh6XSHLRq"
      },
      "id": "IBNlh6XSHLRq"
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Training this model requires Colab Pro with an A100 GPU enabled (~20 GB of GPU RAM).**"
      ],
      "metadata": {
        "id": "VKEoL0s8IGXr"
      },
      "id": "VKEoL0s8IGXr"
    },
    {
      "cell_type": "code",
      "source": [
        "!nvidia-smi"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "xwU_dnSGIE0o",
        "outputId": "e6b9b838-ead6-4b11-8965-454598a66a3c"
      },
      "id": "xwU_dnSGIE0o",
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Sun Jul  9 07:22:20 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   31C    P0    48W / 400W |      0MiB / 40960MiB |      0%      Default |\n",
            "|                               |                      |             Disabled |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "|  No running processes found                                                 |\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Setup"
      ],
      "metadata": {
        "id": "pfnc1jJoEMp3"
      },
      "id": "pfnc1jJoEMp3"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "546a1e28-bbe1-4cac-832c-090d58aa4b5f",
      "metadata": {
        "id": "546a1e28-bbe1-4cac-832c-090d58aa4b5f"
      },
      "outputs": [],
      "source": [
        "!git clone https://github.com/neild0/StyleDrop-PyTorch-Interactive\n",
        "!mv StyleDrop-PyTorch-Interactive StyleDrop-PyTorch\n",
        "\n",
        "!pip install omegaconf gdown accelerate==0.12.0 absl-py ml_collections einops wandb ftfy==6.1.1 transformers==4.23.1 loguru webdataset==0.2.5 gradio xformers"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!git lfs install\n",
        "!git clone https://huggingface.co/nzl-thu/MUSE\n",
        "!mv MUSE/assets/* StyleDrop-PyTorch/assets/\n",
        "!rm -rf MUSE"
      ],
      "metadata": {
        "id": "0bpgLXkHcg8v"
      },
      "id": "0bpgLXkHcg8v",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!gdown '13S_unB87n6KKuuMdyMnyExW0G1kplTbP' --output StyleDrop-PyTorch/assets/vqgan_jax_strongaug.ckpt"
      ],
      "metadata": {
        "id": "IJZbaAOVY3_3"
      },
      "id": "IJZbaAOVY3_3",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!python StyleDrop-PyTorch/extract_empty_feature.py"
      ],
      "metadata": {
        "id": "sVInHq6dbAa0"
      },
      "id": "sVInHq6dbAa0",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Set Style"
      ],
      "metadata": {
        "id": "xXcfSxdoHHqa"
      },
      "id": "xXcfSxdoHHqa"
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Upload Style File\n",
        "\n",
        "import shutil\n",
        "from google.colab import files\n",
        "import os\n",
        "from IPython.display import Image, display\n",
        "\n",
        "# directory the uploaded style image is moved to\n",
        "directory = \"/content/StyleDrop-PyTorch/data\"\n",
        "\n",
        "# open the upload dialog; returns a dict mapping each file name to its contents\n",
        "uploaded = files.upload()\n",
        "\n",
        "if len(uploaded) != 1:\n",
        "    raise Exception(\"Please select exactly one file\")\n",
        "\n",
        "for fn in uploaded.keys():\n",
        "\n",
        "    extension = os.path.splitext(fn)[1]\n",
        "\n",
        "    if extension == '.jpg':\n",
        "        # move the file and get the new path\n",
        "        new_path = shutil.move(fn, f\"{directory}/{fn.replace(' ', '_')}\")\n",
        "\n",
        "        # display the image\n",
        "        display(Image(filename=new_path))\n",
        "\n",
        "    else:\n",
        "        raise Exception(\"Please only select jpg files\")\n"
      ],
      "metadata": {
        "cellView": "form",
        "id": "AYqDPmacEWX2"
      },
      "id": "AYqDPmacEWX2",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Describe style image\n",
        "import json\n",
        "import os\n",
        "\n",
        "object_in_image = \"a woman walking a dog\" #@param {type:\"string\"}\n",
        "style_of_image = \"flat cartoon\" #@param {type:\"string\"}\n",
        "\n",
        "file_name = os.path.basename(new_path)\n",
        "one_style = {file_name: [object_in_image, f\"in {style_of_image} style\"]}\n",
        "with open(\"/content/StyleDrop-PyTorch/data/one_style.json\", \"w\") as f:\n",
        "    json.dump(one_style, f)"
      ],
      "metadata": {
        "cellView": "form",
        "id": "-YiFaov_EvHW"
      },
      "id": "-YiFaov_EvHW",
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Train Model"
      ],
      "metadata": {
        "id": "zE9WB5BRF9JM"
      },
      "id": "zE9WB5BRF9JM"
    },
    {
      "cell_type": "code",
      "source": [
        "import glob\n",
        "import os\n",
        "\n",
        "# unset EVAL_CKPT and ADAPTER if an earlier cell or run left them in the environment\n",
        "for var in ('EVAL_CKPT', 'ADAPTER'):\n",
        "    os.environ.pop(var, None)\n",
        "\n",
        "# name the output folder after the style, with spaces replaced by underscores\n",
        "style_output_folder = style_of_image.replace(\" \", \"_\")\n",
        "os.environ['OUTPUT_DIR'] = style_output_folder\n",
        "\n",
        "!cd StyleDrop-PyTorch && accelerate launch --mixed_precision fp16 train_t2i_colab_v2.py --config=configs/custom.py\n",
        "\n",
        "# list the checkpoint subdirectories written during training\n",
        "ckpt_root = f\"/content/StyleDrop-PyTorch/{style_output_folder}/ckpts_II\"\n",
        "subdirs = sorted(glob.glob(os.path.join(ckpt_root, '*/')))\n",
        "if not subdirs:\n",
        "    raise FileNotFoundError(f\"No checkpoints found in {ckpt_root}\")\n",
        "\n",
        "# take the last subdirectory in lexicographic order\n",
        "last_subdir = subdirs[-1]\n",
        "adapter = os.path.join(last_subdir, \"adapter.pth\")\n",
        "\n",
        "# move the trained adapter into style_adapter/<style name>\n",
        "new_adapter_path = f\"/content/StyleDrop-PyTorch/style_adapter/{style_output_folder}\"\n",
        "os.makedirs(new_adapter_path, exist_ok=True)\n",
        "os.rename(adapter, os.path.join(new_adapter_path, \"adapter.pth\"))\n"
      ],
      "metadata": {
        "id": "epmOxNBY4BbL"
      },
      "id": "epmOxNBY4BbL",
      "execution_count": 15,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Run Gradio Interface"
      ],
      "metadata": {
        "id": "XPw0Fs3wHAb6"
      },
      "id": "XPw0Fs3wHAb6"
    },
    {
      "cell_type": "code",
      "source": [
        "!cd StyleDrop-PyTorch && python3 gradio_demo.py"
      ],
      "metadata": {
        "id": "dv3ULWt4wDqH"
      },
      "id": "dv3ULWt4wDqH",
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.16"
    },
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "A100",
      "include_colab_link": true
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 5
}